"""Measure Wilson loops on a 2D lattice.

This module provides a function ``measure_wilson_loops`` which computes
Wilson loop observables for square loops of various sizes on a given
lattice and link-variable configuration.  It supports both U(1) and
SU(N) gauge groups; only periodic boundary conditions are supported.

The implementation depends on no external libraries beyond NumPy.  It is
a lightweight subset of the full measurement pipeline used in other
repositories, but is sufficient for validating lattice gauge simulations.
"""

from typing import Dict, List, Tuple
import numpy as np


def build_link_map(lattice: List[Tuple[Tuple[int, int], int]]) -> Dict[Tuple[Tuple[int, int], int], int]:
    """Build a mapping from link specification to its index in the lattice array.

    Each entry ``((x, y), mu)`` of ``lattice`` becomes a key whose value is the
    entry's position in ``lattice``.  The site coordinates are converted to a
    tuple so that lattices built from lists (e.g. deserialised from JSON) still
    yield hashable keys matching the ``((x, y), mu)`` tuples used for lookup.
    If the same link appears more than once, the last occurrence wins.
    """
    return {(tuple(pos), mu): idx for idx, (pos, mu) in enumerate(lattice)}


def _square_loop_links(x: int, y: int, size: int, nx: int, ny: int):
    """Yield ``(link_key, forward)`` pairs tracing one ``size``x``size`` loop.

    The loop starts at site ``(x, y)`` and is traversed right, up, left, then
    down, with coordinates wrapped periodically on an ``nx`` x ``ny`` grid.
    ``forward`` is False for the left and down edges, whose links are stored
    in the +x / +y direction and therefore must be inverted by the caller.
    """
    right = (x + size) % nx
    top = (y + size) % ny
    for i in range(size):  # bottom edge, along +x
        yield (((x + i) % nx, y), 0), True
    for j in range(size):  # right edge, along +y
        yield ((right, (y + j) % ny), 1), True
    for i in range(size):  # top edge, walked in -x: inverse of +x links
        yield (((x + size - 1 - i) % nx, top), 0), False
    for j in range(size):  # left edge, walked in -y: inverse of +y links
        yield ((x, (y + size - 1 - j) % ny), 1), False


def measure_wilson_loops(
    lattice: List[Tuple[Tuple[int, int], int]],
    U: np.ndarray,
    sizes: List[int],
    bc: str = 'periodic'
) -> Dict[int, List[complex]]:
    """Compute Wilson loops for a set of square loop sizes.

    Parameters
    ----------
    lattice : list
        List of links ``[((x,y), mu), ...]`` describing the lattice geometry.
        ``mu == 0`` is the +x link and ``mu == 1`` the +y link leaving site
        ``(x, y)``.  Site coordinates are expected to run from 0 to
        ``Lx - 1`` / ``Ly - 1``; the grid extents are taken per axis from the
        largest coordinate, so rectangular lattices are supported.
    U : np.ndarray
        Array of link variables indexed like ``lattice``.  For U(1) this is
        complex 1D; for SU(N) it is an array of shape ``(num_links, N, N)``.
    sizes : list of int
        Edge lengths ``L`` of the square loops to measure.  Repeated sizes
        are measured only once.
    bc : str, optional
        Boundary conditions: currently only ``'periodic'`` is supported.

    Returns
    -------
    dict
        A mapping from loop size ``L`` to the list of complex Wilson-loop
        values measured at all starting positions, ordered with ``x`` as the
        outer and ``y`` as the inner index.  For SU(N) the value is the
        (unnormalised) trace of the ordered link product.

    Raises
    ------
    ValueError
        If ``bc`` is not ``'periodic'``.
    """
    if bc != 'periodic':
        raise ValueError(
            f"unsupported boundary condition {bc!r}; only 'periodic' is supported"
        )

    # Dict keys deduplicate ``sizes`` while keeping first-seen order, so a
    # repeated size is not measured (and appended) twice.
    results: Dict[int, List[complex]] = {L: [] for L in sizes}

    sites = {tuple(pos) for pos, _ in lattice}
    if not sites:
        return results
    # Per-axis extents (instead of sqrt of the site count) so that
    # rectangular lattices wrap correctly.
    nx = max(p[0] for p in sites) + 1
    ny = max(p[1] for p in sites) + 1

    link_map = build_link_map(lattice)
    U = np.asarray(U)
    sample = U[0]
    # A 0-d element means U(1) phases; anything else is treated as a matrix.
    is_scalar = np.ndim(sample) == 0
    # Safe to share: every update below rebinds W rather than mutating it.
    identity = 1 + 0j if is_scalar else np.eye(np.shape(sample)[0], dtype=complex)

    for L, values in results.items():
        for x in range(nx):
            for y in range(ny):
                W = identity
                for link, forward in _square_loop_links(x, y, L, nx, ny):
                    u = U[link_map[link]]
                    if not forward:
                        # Inverse of a unitary link is its conjugate transpose.
                        u = np.conjugate(u) if is_scalar else np.conjugate(u).T
                    W = W * u if is_scalar else W @ u
                values.append(W if is_scalar else np.trace(W))
    return results